
#include "scatter_reduce_tiling.h"

#include <cstdint>
#include <cstring>
#include <limits>

#include "register/op_def_registry.h"


namespace optiling {
/// Maps the `reduce` attribute string to the integer code stored in the tiling data.
/// @param reduce      Attribute value; may be nullptr.
/// @param reduceType  Receives 0=sum, 1=prod, 2=mean, 3=amax, 4=amin on success; untouched otherwise.
/// @return true if `reduce` names a supported reduction, false otherwise.
static bool ParseReduceType(const char *reduce, int &reduceType)
{
    if (reduce == nullptr) {
        return false;
    }
    // Position in this table is the code written into the tiling data.
    static const char *const kReduceNames[] = {"sum", "prod", "mean", "amax", "amin"};
    int code = 0;
    for (const char *name : kReduceNames) {
        if (strcmp(reduce, name) == 0) {
            reduceType = code;
            return true;
        }
        ++code;
    }
    return false;
}

/// Host-side tiling for ScatterReduce.
///
/// Reads the `dim`, `reduce` and optional `include_self` attributes and the shape
/// of input 0 ("self"). It splits that shape around `dim` into
/// beforeNum x calcNum x afterNum, writes these counts and the encoded attributes
/// into ScatterReduceTilingData, and selects the tiling key and block dim.
///
/// @return ge::GRAPH_SUCCESS on success. Returns ge::GRAPH_FAILED in these cases:
///         - an attribute, input description or tiling buffer is unavailable;
///         - `reduce` is unsupported;
///         - `dim` is out of range;
///         - an element count exceeds the int32 range.
static ge::graphStatus TilingFunc(gert::TilingContext* context)
{
    // Attributes are read by position: 0 = dim, 1 = reduce, 2 = include_self.
    const auto *runtimeAttrs = context->GetAttrs();
    if (runtimeAttrs == nullptr) {
        return ge::GRAPH_FAILED;
    }
    const int64_t *dimAttr = runtimeAttrs->GetInt(0);
    const char *reduce = runtimeAttrs->GetStr(1);
    if (dimAttr == nullptr || reduce == nullptr) {
        return ge::GRAPH_FAILED;
    }
    // include_self is optional; it stays true when the attribute is not present.
    bool includeSelf = true;
    if (runtimeAttrs->GetAttrNum() > 2) {
        const bool *includeSelfAttr = runtimeAttrs->GetBool(2);
        if (includeSelfAttr != nullptr) {
            includeSelf = *includeSelfAttr;
        }
    }

    int reduceType = 0;
    if (!ParseReduceType(reduce, reduceType)) {
        return ge::GRAPH_FAILED;
    }

    const auto *inputShape = context->GetInputShape(0);
    if (inputShape == nullptr) {
        return ge::GRAPH_FAILED;
    }
    const auto &shapeX = inputShape->GetOriginShape();
    const int64_t rank = static_cast<int64_t>(shapeX.GetDimNum());

    // Wrap negative dims PyTorch-style. A 0-d tensor is treated as shape [1],
    // so dim 0 and -1 are both accepted for it.
    const int64_t wrapRank = rank > 0 ? rank : 1;
    int64_t dim = *dimAttr;
    if (dim < -wrapRank || dim >= wrapRank) {
        return ge::GRAPH_FAILED;
    }
    if (dim < 0) {
        dim += wrapRank;
    }

    // Split the shape around `dim`:
    //   beforeNum = product of the extents before `dim`;
    //   calcNum   = the extent of `dim` itself;
    //   afterNum  = product of the extents after `dim`, i.e. the stride between
    //               consecutive entries along `dim`.
    // Counts are accumulated in 64 bits so oversized shapes are detected rather than wrapped.
    int64_t beforeNum = 1;
    int64_t calcNum = 1;
    int64_t afterNum = 1;
    for (int64_t i = 0; i < rank; ++i) {
        const int64_t extent = shapeX.GetDim(i);
        if (i < dim) {
            beforeNum *= extent;
        } else if (i == dim) {
            calcNum = extent;
        } else {
            afterNum *= extent;
        }
    }
    const int64_t totalNum = beforeNum * calcNum * afterNum;
    const int64_t rowCount = beforeNum * afterNum;

    // The counts are handed to the tiling setters as int, and rowCount also
    // feeds SetBlockDim. Refuse shapes whose counts would not fit.
    constexpr int64_t kMaxCount = std::numeric_limits<int32_t>::max();
    if (totalNum > kMaxCount || rowCount > kMaxCount) {
        return ge::GRAPH_FAILED;
    }

    ScatterReduceTilingData tiling;
    tiling.set_beforeNum(static_cast<int>(beforeNum));
    tiling.set_calcNum(static_cast<int>(calcNum));
    tiling.set_afterNum(static_cast<int>(afterNum));
    tiling.set_totalNum(static_cast<int>(totalNum));
    tiling.set_dim(static_cast<int>(dim));
    tiling.set_reduce(reduceType);
    tiling.set_include_self(includeSelf ? 1 : 0);

    const auto *inputTensor = context->GetInputTensor(0);
    if (inputTensor == nullptr) {
        return ge::GRAPH_FAILED;
    }
    const auto dataType = inputTensor->GetDataType();

    // Tiling key 0 is a specialised path: amin over dim 0, include_self == false,
    // float32 input. It always launches a fixed number of blocks.
    // Tiling key 1 is the generic path; it launches beforeNum * afterNum blocks.
    // NOTE(review): 40 presumably matches the vector-core count of ascend910b; confirm.
    constexpr uint32_t kSpecialisedBlockDim = 40;
    constexpr int kReduceAmin = 4;
    if (dim == 0 && reduceType == kReduceAmin && !includeSelf && dataType == ge::DT_FLOAT) {
        context->SetTilingKey(0);
        context->SetBlockDim(kSpecialisedBlockDim);
    } else {
        context->SetTilingKey(1);
        context->SetBlockDim(static_cast<uint32_t>(rowCount));
    }

    auto *rawTilingData = context->GetRawTilingData();
    if (rawTilingData == nullptr) {
        return ge::GRAPH_FAILED;
    }
    tiling.SaveToBuffer(rawTilingData->GetData(), rawTilingData->GetCapacity());
    rawTilingData->SetDataSize(tiling.GetDataSize());

    return ge::GRAPH_SUCCESS;
}
}


namespace ge {
/// Shape inference for ScatterReduce: output 0 ("y") takes exactly the shape of input 0 ("self").
/// @return GRAPH_SUCCESS, or GRAPH_FAILED if either shape is unavailable from the context.
static ge::graphStatus InferShape(gert::InferShapeContext* context)
{
    const gert::Shape* selfShape = context->GetInputShape(0);
    gert::Shape* yShape = context->GetOutputShape(0);
    // Guard against a missing input/output description instead of dereferencing null.
    if (selfShape == nullptr || yShape == nullptr) {
        return GRAPH_FAILED;
    }
    *yShape = *selfShape;
    return GRAPH_SUCCESS;
}
}


namespace ops {
class ScatterReduce : public OpDef {
public:
    explicit ScatterReduce(const char* name) : OpDef(name)
    {
        this->Input("self")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
        this->Input("index")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT32, ge::DT_INT32})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
        this->Input("src")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
        this->Output("y")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
        this->Attr("dim").Int();
        this->Attr("reduce").String();
        this->Attr("include_self").AttrType(OPTIONAL).Bool(true);

        this->SetInferShape(ge::InferShape);

        this->AICore()
            .SetTiling(optiling::TilingFunc);
        this->AICore().AddConfig("ascend910b");

    }
};

OP_ADD(ScatterReduce);
}
